{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import numpy as np\n",
    "from dataset import Dataset, Stack_Dataset\n",
    "import models\n",
    "from torch.utils import data\n",
    "from utils import load_model, save_model, Visualizer, Logger, write_result, get_loss_weight, get_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_score(logit, label):\n",
    "    \"\"\"Score the top-5 predictions of `logit` against the multi-hot `label`.\n",
    "\n",
    "    Args:\n",
    "        logit: 2-D tensor (samples x classes) of per-class scores.\n",
    "        label: iterable of CPU tensors, one per sample, where an entry equal\n",
    "            to 1 marks a true class.\n",
    "\n",
    "    Returns:\n",
    "        (precision, recall, score) as Python floats, where\n",
    "        precision = sum over positions of (hit rate at position) / ln(2 + position),\n",
    "        recall    = total hits / total number of marked labels,\n",
    "        score     = precision * recall / (precision + recall).\n",
    "        All three are 0.0 when there are no samples or no marked labels, and\n",
    "        score is 0.0 when precision + recall is 0.\n",
    "    \"\"\"\n",
    "    # Never ask topk for more entries than there are classes.\n",
    "    top_k = min(5, logit.size(1))\n",
    "    # tolist() yields plain ints; iterating a tensor row can yield 0-dim tensors,\n",
    "    # which hash by identity and would never be found in the label set below.\n",
    "    predict_label_list = logit.topk(top_k, 1)[1].tolist()\n",
    "    marked_label_list = [np.where(ii.numpy() == 1)[0].tolist() for ii in label]\n",
    "\n",
    "    right_label_num = 0\n",
    "    right_label_at_pos_num = [0] * top_k\n",
    "    sample_num = 0\n",
    "    all_marked_label_num = 0\n",
    "    for predict_labels, marked_labels in zip(predict_label_list, marked_label_list):\n",
    "        sample_num += 1\n",
    "        marked_label_set = set(marked_labels)\n",
    "        all_marked_label_num += len(marked_label_set)\n",
    "        # `predicted` (not `label`) so the function argument is not shadowed.\n",
    "        for pos, predicted in enumerate(predict_labels):\n",
    "            if predicted in marked_label_set:\n",
    "                right_label_num += 1\n",
    "                right_label_at_pos_num[pos] += 1\n",
    "\n",
    "    # Nothing to score: avoid dividing by zero below.\n",
    "    if sample_num == 0 or all_marked_label_num == 0:\n",
    "        return 0.0, 0.0, 0.0\n",
    "\n",
    "    precision = 0.0\n",
    "    for pos, right_num in enumerate(right_label_at_pos_num):\n",
    "        # Hits at later positions are discounted by ln(2 + position).\n",
    "        precision += (right_num / float(sample_num)) / np.log(2.0 + pos)\n",
    "    precision = float(precision)\n",
    "    recall = float(right_label_num) / all_marked_label_num\n",
    "    if precision + recall == 0:\n",
    "        return precision, recall, 0.0\n",
    "    score = (precision * recall) / (precision + recall)\n",
    "\n",
    "    return precision, recall, score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_loss_weight(logit, label):\n",
    "    \"\"\"Per-class miss rate of the top-5 predictions.\n",
    "\n",
    "    Args:\n",
    "        logit: 2-D tensor (samples x classes) of per-class scores.\n",
    "        label: iterable of CPU tensors, one per sample, where an entry equal\n",
    "            to 1 marks a true class.\n",
    "\n",
    "    Returns:\n",
    "        1-D float tensor of length `logit.size(1)`. Entry c is the fraction of\n",
    "        samples marked with class c whose top-5 predictions do not contain c.\n",
    "        Classes that are not marked in any sample get 0 rather than NaN.\n",
    "    \"\"\"\n",
    "    class_num = logit.size(1)\n",
    "    # Never ask topk for more entries than there are classes.\n",
    "    top_k = min(5, class_num)\n",
    "    # tolist() yields plain ints so membership tests compare values, not tensors.\n",
    "    predict_label_list = logit.topk(top_k, 1)[1].tolist()\n",
    "    marked_label_list = [np.where(ii.numpy() == 1)[0].tolist() for ii in label]\n",
    "    sample_per_class = torch.zeros(class_num)\n",
    "    error_per_class = torch.zeros(class_num)\n",
    "    for predict_labels, marked_labels in zip(predict_label_list, marked_label_list):\n",
    "        predicted_set = set(predict_labels)\n",
    "        for true_label in marked_labels:\n",
    "            sample_per_class[true_label] += 1\n",
    "            if true_label not in predicted_set:\n",
    "                error_per_class[true_label] += 1\n",
    "    # Clamp the denominator so classes with no sample give 0/1 = 0, not 0/0 = NaN.\n",
    "    return error_per_class / sample_per_class.clamp(min=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Directory holding the saved *_res.pt files and label.pt.\n",
    "result_dir = '/home/dyj/'\n",
    "\n",
    "# One *_res.pt path per entry; the list is handed to Stack_Dataset below.\n",
    "resmat = [\n",
    "    result_dir + 'TextCNN_2017-07-27#10:15:20_res.pt',\n",
    "    result_dir + 'TextCNN_2017-07-27#10:32:21_res.pt',\n",
    "    result_dir + 'RNN_2017-07-27#10:48:05_res.pt',\n",
    "    result_dir + 'RNN_2017-07-27#10:41:03_res.pt',\n",
    "    result_dir + 'RCNN_2017-07-27#11:01:07_res.pt',\n",
    "    result_dir + 'RCNNcha_2017-07-27#16:19:23_res.pt',\n",
    "    result_dir + 'FastText4_2017-07-28#15:14:47_res.pt',\n",
    "    result_dir + 'FastText1_2017-07-29#10:31:43_res.pt',\n",
    "]\n",
    "stack_num = len(resmat)\n",
    "\n",
    "# `label` ends up holding whatever object is stored in label.pt.\n",
    "label = torch.load(result_dir + 'label.pt')\n",
    "\n",
    "# shuffle=False keeps batches in dataset order.\n",
    "train_dataset = Stack_Dataset(resmat=resmat, test=True)\n",
    "train_loader = data.DataLoader(train_dataset, batch_size=256, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load model snapshots/Stack/epoch_10_2017-07-30#23:08:24.params successful!\n"
     ]
    }
   ],
   "source": [
    "# Build the 'Stack' model, load its saved parameters and run it over train_loader.\n",
    "Model = getattr(models, 'Stack')\n",
    "opt = {'stack_num': stack_num, 'class_num': 1999}\n",
    "model = Model(opt)\n",
    "model = load_model(model, model_dir='snapshots', model_name='Stack')\n",
    "model.cuda()\n",
    "# Inference only: switch layers such as dropout / batch-norm to evaluation mode.\n",
    "model.eval()\n",
    "\n",
    "# One output row per dataset sample. Sized from the dataset and opt instead of\n",
    "# hard-coded numbers, and zero-initialised so no row can hold uninitialised memory.\n",
    "res = torch.zeros(len(train_dataset), opt['class_num'])\n",
    "offset = 0\n",
    "for batch in train_loader:\n",
    "    batch_size = batch[0].size(0)\n",
    "    # Local name on purpose: the global `resmat` list of file paths stays intact.\n",
    "    inputs = [Variable(ii).cuda() for ii in batch]\n",
    "    logit = model(inputs)\n",
    "    # A running offset stays correct whatever batch size the loader uses.\n",
    "    res[offset:offset + batch_size] = logit.data.cpu()\n",
    "    offset += batch_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.4680769910081861, 0.5993500186563444, 0.4255975992557442)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the stacked outputs against the labels; the displayed tuple is\n",
    "# (precision, recall, score) from the get_score defined in this notebook.\n",
    "get_score(res, label)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
